package com.yjy.mdc.filter;

import org.slf4j.MDC;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.AbstractRequestLoggingFilter;

import javax.servlet.http.HttpServletRequest;
import java.util.UUID;

/**
 * Request filter that places a trace ID into the SLF4J {@link MDC} under the key
 * {@link #TRACE_ID} before a request is handled and removes it again afterwards.
 *
 * <p>If the incoming request carries a non-blank {@code traceId} header, that value is reused
 * so the ID supplied by the upstream caller is kept; otherwise a new random ID is generated.
 */
@Component
public class TraceIdLogFilter extends AbstractRequestLoggingFilter {
    /** Name of the request header that is read and of the MDC key that is written. */
    public static final String TRACE_ID = "traceId";

    /**
     * Stores the trace ID for the given request in the MDC.
     *
     * @param request the current request; its {@code traceId} header is consulted
     * @param message the log message supplied by the parent filter (not used here)
     */
    @Override
    protected void beforeRequest(HttpServletRequest request, String message) {
        MDC.put(TRACE_ID, resolveTraceId(request));
    }

    /**
     * Removes the trace ID from the MDC once the request has been handled.
     *
     * @param request the current request (not used here)
     * @param message the log message supplied by the parent filter (not used here)
     */
    @Override
    protected void afterRequest(HttpServletRequest request, String message) {
        MDC.remove(TRACE_ID);
    }

    /**
     * Returns the trace ID to use for the request: the upstream caller's ID when one was sent,
     * otherwise a freshly generated UUID with the dashes stripped.
     */
    private static String resolveTraceId(HttpServletRequest request) {
        String upstreamId = request.getHeader(TRACE_ID);
        // A header that is present but empty or whitespace-only is treated like a missing one;
        // otherwise the MDC would hold a blank trace ID that cannot be used for correlation.
        if (upstreamId != null && !upstreamId.trim().isEmpty()) {
            return upstreamId;
        }
        return UUID.randomUUID().toString().replace("-", "");
    }
}